Histopathologic Cancer Detection: Problem and Data Description¶
This notebook provides an introduction to the Histopathologic Cancer Detection problem and a description of the dataset used in this project.
1. Problem Description¶
1.1 What is Histopathologic Cancer Detection?¶
Histopathology is the microscopic examination of tissue samples to study the manifestations of disease. Pathologists examine tissue sections under a microscope to identify abnormal cells and patterns that may indicate cancer or other diseases. This process is crucial for cancer diagnosis, staging, and treatment planning.
In the Histopathologic Cancer Detection challenge, the goal is to create an algorithm to identify metastatic cancer in small image patches taken from larger digital pathology scans. Specifically, we are tasked with determining whether the center 32x32px region of a 96x96px patch contains at least one pixel of tumor tissue.
As described in the Kaggle competition:
In this competition, you must create an algorithm to identify metastatic cancer in small image patches taken from larger digital pathology scans. The data for this competition is a slightly modified version of the PatchCamelyon (PCam) benchmark dataset (the original PCam dataset contains duplicate images due to its probabilistic sampling, however, the version presented on Kaggle does not contain duplicates).
1.2 The Challenge of Metastatic Cancer Detection¶
Metastatic cancer occurs when cancer cells break away from the primary tumor site, travel through the blood or lymphatic system, and form new tumors in other parts of the body. Detecting these metastases is critical for:
- Accurate cancer staging: The presence and extent of metastases determine the cancer stage, which guides treatment decisions.
- Treatment planning: Different treatment approaches may be needed for metastatic versus non-metastatic cancer.
- Prognosis assessment: Metastatic cancer generally has a poorer prognosis than localized cancer.
However, manual examination of histopathology slides is:
- Time-consuming: Pathologists must carefully examine large tissue samples
- Subject to variability: Different pathologists may interpret the same slide differently
- Prone to human error: Fatigue and workload can affect accuracy
Automated detection systems using machine learning can help address these challenges by providing consistent, rapid, and potentially more accurate assessments.
2. Medical Importance¶
2.1 Why is This Problem Important?¶
Developing accurate automated systems for histopathologic cancer detection has several significant benefits:
Clinical Benefits¶
- Improved diagnostic accuracy: Reducing human error and variability in interpretation
- Faster diagnosis: Accelerating the diagnostic process, potentially leading to earlier treatment
- Reduced workload for pathologists: Allowing them to focus on more complex cases
- Standardization: Providing consistent diagnostic criteria across different healthcare settings
Research Benefits¶
- Quantitative analysis: Enabling more precise measurement of disease characteristics
- Large-scale studies: Facilitating research on large datasets that would be impractical to analyze manually
- Novel biomarker discovery: Potentially identifying subtle patterns associated with disease outcomes
Global Health Impact¶
- Addressing pathologist shortages: Many regions worldwide have insufficient pathology expertise
- Telemedicine support: Enhancing remote diagnostic capabilities
- Democratizing expertise: Making high-quality diagnostic support more widely available
2.2 Current State of Automated Histopathology Analysis¶
Recent advances in deep learning have shown promising results in histopathology image analysis. Several studies have demonstrated that convolutional neural networks (CNNs) can achieve performance comparable to or even exceeding that of pathologists in specific diagnostic tasks.
However, challenges remain in:
- Generalizability across different laboratories and staining protocols
- Interpretability of model decisions
- Integration into clinical workflows
- Regulatory approval for clinical use
This project contributes to the ongoing effort to improve automated histopathologic analysis by developing and evaluating models for metastatic cancer detection.
3. Dataset Description¶
3.1 Dataset Overview¶
The dataset for this project comes from the PatchCamelyon (PCam) benchmark dataset, which is derived from the Camelyon16 challenge. It consists of histopathologic scans of lymph node sections, where the task is to identify metastatic tissue.
As described in the Kaggle competition:
PCam is highly interesting for both its size, simplicity to get started on, and approachability. In the authors' words:
[PCam] packs the clinically-relevant task of metastasis detection into a straight-forward binary image classification task, akin to CIFAR-10 and MNIST. Models can easily be trained on a single GPU in a couple hours, and achieve competitive scores in the Camelyon16 tasks of tumor detection and whole-slide image diagnosis. Furthermore, the balance between task-difficulty and tractability makes it a prime suspect for fundamental machine learning research on topics as active learning, model uncertainty, and explainability.
The PCam dataset was created to serve as a benchmark for machine learning algorithms in medical image analysis, providing a more accessible format than the original whole-slide images. The version used in the Kaggle competition is slightly modified from the original PCam dataset in that it does not contain duplicate images that were present in the original due to probabilistic sampling.
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from PIL import Image
# Set paths to the dataset
# Note: Update these paths based on your local setup
BASE_DIR = '../data'
TRAIN_DIR = os.path.join(BASE_DIR, 'train')
TEST_DIR = os.path.join(BASE_DIR, 'test')
TRAIN_LABELS_PATH = os.path.join(BASE_DIR, 'train_labels.csv')
# Load the training labels
try:
train_labels = pd.read_csv(TRAIN_LABELS_PATH)
print(f"Successfully loaded training labels with shape: {train_labels.shape}")
print("\nSample of training labels:")
display(train_labels.head())
except Exception as e:
print(f"Error loading training labels: {e}")
print("Please ensure the dataset is downloaded and the paths are correctly set.")
Successfully loaded training labels with shape: (220025, 2) Sample of training labels:
| id | label | |
|---|---|---|
| 0 | f38a6374c348f90b587e046aac6079959adf3835 | 0 |
| 1 | c18f2d887b7ae4f6742ee445113fa1aef383ed77 | 1 |
| 2 | 755db6279dae599ebb4d39a9123cce439965282d | 0 |
| 3 | bc3f0c64fb968ff4a8bd33af6971ecae77c75e08 | 0 |
| 4 | 068aba587a4950175d04c680d38943fd488d6a9d | 0 |
3.2 Dataset Characteristics¶
Image Properties¶
- Image format: The images are provided in standard image formats (e.g., JPEG, PNG)
- Image dimensions: Each patch is 96×96 pixels
- Color channels: RGB (3 channels)
- Resolution: The images represent a 96×96μm area of the original slide (at 10× magnification)
- Task: Determine whether the center 32×32px region contains at least one pixel of tumor tissue
Dataset Size¶
- Training set: Approximately 220,000 images
- Test set: Approximately 57,500 images
Class Distribution¶
The dataset has two classes:
- Class 0: Normal tissue (no metastatic cancer)
- Class 1: Metastatic cancer tissue
Let's examine the class distribution in the training set:
# Analyze class distribution
try:
class_distribution = train_labels['label'].value_counts()
print("Class distribution in training set:")
print(class_distribution)
# Calculate percentages
class_percentages = class_distribution / len(train_labels) * 100
print("\nClass distribution percentages:")
for label, percentage in class_percentages.items():
print(f"Class {label}: {percentage:.2f}%")
# Visualize class distribution
plt.figure(figsize=(8, 6))
plt.bar(['Normal Tissue (0)', 'Metastatic Cancer (1)'], class_distribution.values)
plt.title('Class Distribution in Training Set')
plt.ylabel('Number of Samples')
plt.grid(axis='y', alpha=0.3)
for i, count in enumerate(class_distribution.values):
plt.text(i, count + 500, f"{count}\n({class_percentages.values[i]:.1f}%)",
ha='center', va='bottom')
plt.show()
except Exception as e:
print(f"Error analyzing class distribution: {e}")
print("This will be completed once the dataset is loaded.")
Class distribution in training set: label 0 130908 1 89117 Name: count, dtype: int64 Class distribution percentages: Class 0: 59.50% Class 1: 40.50%
3.3 Data Source and Context¶
Origin of the Data¶
The PCam dataset is derived from the Camelyon16 challenge, which was organized to evaluate algorithms for detecting metastatic breast cancer in lymph node tissue. The original Camelyon16 dataset consists of whole-slide images (WSIs) of lymph node sections from two medical centers in the Netherlands:
- Radboud University Medical Center (Nijmegen)
- University Medical Center Utrecht
The slides were digitized using a 40× objective lens, resulting in a pixel resolution of 0.243 microns. For the PCam dataset, patches of 96×96 pixels were extracted from these WSIs.
Medical Context¶
In breast cancer staging, examining lymph nodes for metastases is a critical step. Sentinel lymph node biopsy (SLNB) is a procedure where the first lymph node(s) to which cancer is likely to spread is removed and examined. The presence of metastatic cells in these lymph nodes indicates that the cancer has begun to spread beyond the primary tumor site.
The traditional workflow for lymph node examination involves:
- Surgical removal of lymph nodes
- Preparation of tissue sections (fixing, embedding, sectioning, staining)
- Microscopic examination by a pathologist
- Diagnosis based on the presence or absence of cancer cells
This dataset represents step 3 in this process, where tissue sections have been digitized for analysis.
Staining Technique¶
The tissue sections in this dataset are stained with Hematoxylin and Eosin (H&E), which is the standard staining method in histopathology:
- Hematoxylin: Stains cell nuclei blue/purple
- Eosin: Stains cytoplasm and extracellular matrix pink/red
This staining allows pathologists (and our algorithms) to distinguish different cellular structures and identify abnormal patterns characteristic of cancer.
3.4 Visualization of Sample Images¶
Let's visualize some sample images from both classes to better understand the data:
def load_and_display_samples(num_samples=4):
"""Load and display sample images from each class"""
try:
# Get sample IDs for each class
normal_samples = train_labels[train_labels['label'] == 0]['id'].sample(num_samples).values
cancer_samples = train_labels[train_labels['label'] == 1]['id'].sample(num_samples).values
# Set up the figure
fig, axes = plt.subplots(2, num_samples, figsize=(num_samples*3, 6))
# Display normal tissue samples
for i, sample_id in enumerate(normal_samples):
img_path = os.path.join(TRAIN_DIR, f"{sample_id}.tif")
img = Image.open(img_path)
axes[0, i].imshow(img)
axes[0, i].set_title(f"Normal Tissue\nID: {sample_id}")
axes[0, i].axis('off')
# Display cancer tissue samples
for i, sample_id in enumerate(cancer_samples):
img_path = os.path.join(TRAIN_DIR, f"{sample_id}.tif")
img = Image.open(img_path)
axes[1, i].imshow(img)
axes[1, i].set_title(f"Metastatic Cancer\nID: {sample_id}")
axes[1, i].axis('off')
plt.tight_layout()
plt.suptitle("Sample Images from the Dataset", fontsize=16, y=1.05)
plt.show()
except Exception as e:
print(f"Error displaying sample images: {e}")
print("This will be completed once the dataset is loaded.")
# Display sample images
load_and_display_samples()
4. Evaluation Metric¶
4.1 Competition Metric¶
The primary evaluation metric for this competition is the Area Under the ROC Curve (AUC-ROC). This metric is particularly suitable for binary classification problems, especially when dealing with medical diagnoses.
As stated in the Kaggle competition:
Submissions are evaluated on area under the ROC curve between the predicted probability and the observed target.
Why AUC-ROC?¶
- Threshold-independent: AUC-ROC evaluates the model's performance across all possible classification thresholds, not just at a single decision threshold.
- Balanced assessment: It works well even with imbalanced datasets, as it considers both sensitivity (true positive rate) and specificity (true negative rate).
- Interpretability: The AUC value represents the probability that the model ranks a random positive example higher than a random negative example.
Interpretation of AUC-ROC Values:¶
- AUC = 0.5: The model has no discriminative ability (equivalent to random guessing)
- 0.5 < AUC < 0.7: Poor discrimination
- 0.7 ≤ AUC < 0.8: Acceptable discrimination
- 0.8 ≤ AUC < 0.9: Excellent discrimination
- AUC ≥ 0.9: Outstanding discrimination
In the medical context of cancer detection, a high AUC is particularly important as it indicates the model's ability to correctly distinguish between cancerous and non-cancerous tissue, minimizing both false positives and false negatives.
4.2 Submission Format¶
For the Kaggle competition, submissions must be in a specific format. As described in the competition:
For each id in the test set, you must predict a probability that center 32x32px region of a patch contains at least one pixel of tumor tissue. The file should contain a header and have the following format:
id,label 0b2ea2a822ad23fdb1b5dd26653da899fbd2c0d5,0 95596b92e5066c5c52466c90b69ff089b39f2737,0 248e6738860e2ebcf6258cdc1f32f299e0c76914,0 etc.
This format requires us to output a probability (between 0 and 1) for each image in the test set, indicating the likelihood that the center region contains tumor tissue.
5. Challenges and Considerations¶
5.1 Technical Challenges¶
Several technical challenges are associated with histopathologic image analysis:
- Visual complexity: Histopathology images contain complex patterns and structures that can be difficult to interpret
- Staining variations: Differences in staining protocols and digitization can affect image appearance
- Contextual information: The 96×96 pixel patches may lack broader contextual information from the whole slide
- Subtle differences: The visual differences between normal and cancerous tissue can be subtle and require expert knowledge to identify
- Computational requirements: Processing and analyzing large numbers of high-resolution images requires significant computational resources
5.2 Clinical Considerations¶
From a clinical perspective, several factors are important to consider:
- False negatives: Missing cancer (false negatives) can lead to delayed treatment and poorer outcomes
- False positives: Incorrectly identifying normal tissue as cancerous (false positives) can lead to unnecessary treatments and patient anxiety
- Interpretability: Clinicians need to understand why the model made a particular prediction
- Integration: How the model would integrate into existing clinical workflows
- Regulatory approval: Requirements for using AI systems in clinical practice
These considerations will guide our approach to model development and evaluation.
6. Citations and Acknowledgements¶
6.1 Dataset Citations¶
As noted in the Kaggle competition, if you use PCam in a scientific publication, please reference the following papers:
B. S. Veeling, J. Linmans, J. Winkens, T. Cohen, M. Welling. "Rotation Equivariant CNNs for Digital Pathology". arXiv:1806.03962
Ehteshami Bejnordi et al. Diagnostic Assessment of Deep Learning Algorithms for Detection of Lymph Node Metastases in Women With Breast Cancer. JAMA: The Journal of the American Medical Association, 318(22), 2199–2210. doi:jama.2017.14585
6.2 Competition Citation¶
Will Cukierski. Histopathologic Cancer Detection. https://kaggle.com/competitions/histopathologic-cancer-detection, 2018. Kaggle.
6.3 Acknowledgements¶
The Kaggle competition acknowledges the following contributors:
Kaggle is hosting this competition for the machine learning community to use for fun and practice. This dataset was provided by Bas Veeling, with additional input from Babak Ehteshami Bejnordi, Geert Litjens, and Jeroen van der Laak.
You may view and download the official Pcam dataset from GitHub. The data is provided under the CC0 License, following the license of Camelyon16.
7. Summary¶
In this notebook, we have:
- Introduced the problem of histopathologic cancer detection and its importance in cancer diagnosis and treatment
- Described the dataset used for this project, including its characteristics, source, and medical context
- Visualized sample images from both classes (normal and metastatic cancer tissue)
- Discussed the evaluation metric (AUC-ROC) and its relevance to this medical classification task
- Outlined the challenges and considerations associated with this problem
- Provided citations and acknowledgements for the dataset and competition
In the next notebook (02_Exploratory_Data_Analysis), we will perform a more detailed analysis of the dataset, including:
- Comprehensive visualization of the data
- Statistical analysis of image characteristics
- Investigation of potential data quality issues
- Identification of patterns that might inform our modeling approach
Exploratory Data Analysis
In this notebook, we'll perform a comprehensive exploratory data analysis of the Histopathologic Cancer Detection dataset. We'll examine the characteristics of the dataset, visualize examples from each class, analyze the distribution of pixel values, check for data quality issues, and identify potential preprocessing needs.
1. Load Libraries¶
# Load libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
import random
from tqdm.notebook import tqdm
import cv2
# Set plot style
sns.set_style("whitegrid")
# For reproducibility
np.random.seed(42)
random.seed(42)
2. Load Data¶
First, we'll load the dataset and examine its structure.
# Define paths
# Update these paths to match your directory structure
data_dir = '../data'
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')
train_labels_path = os.path.join(data_dir, 'train_labels.csv')
# Check if the paths exist
print(f"Train directory exists: {os.path.exists(train_dir)}")
print(f"Test directory exists: {os.path.exists(test_dir)}")
print(f"Train labels file exists: {os.path.exists(train_labels_path)}")
Train directory exists: True Test directory exists: True Train labels file exists: True
# Load train labels
try:
train_labels = pd.read_csv(train_labels_path)
print(f"Train labels shape: {train_labels.shape}")
print("\nFirst few rows of train_labels:")
display(train_labels.head())
except Exception as e:
print(f"Error loading train labels: {e}")
print("Please make sure the train_labels.csv file is in the correct location.")
Train labels shape: (220025, 2) First few rows of train_labels:
| id | label | |
|---|---|---|
| 0 | f38a6374c348f90b587e046aac6079959adf3835 | 0 |
| 1 | c18f2d887b7ae4f6742ee445113fa1aef383ed77 | 1 |
| 2 | 755db6279dae599ebb4d39a9123cce439965282d | 0 |
| 3 | bc3f0c64fb968ff4a8bd33af6971ecae77c75e08 | 0 |
| 4 | 068aba587a4950175d04c680d38943fd488d6a9d | 0 |
3. Analyze Class Distribution¶
Let's examine the distribution of cancer vs. non-cancer samples in the dataset.
# Analyze class distribution
try:
class_distribution = train_labels['label'].value_counts()
print("Class distribution:")
print(class_distribution)
print(f"Percentage of positive samples: {class_distribution[1] / len(train_labels) * 100:.2f}%")
# Visualize class distribution
plt.figure(figsize=(10, 6))
sns.countplot(x='label', data=train_labels, palette=['skyblue', 'salmon'])
plt.title('Class Distribution: Cancer vs. Non-Cancer', fontsize=16)
plt.xlabel('Label (0: Non-Cancer, 1: Cancer)', fontsize=14)
plt.ylabel('Count', fontsize=14)
plt.xticks([0, 1], ['Non-Cancer (0)', 'Cancer (1)'])
plt.grid(True, alpha=0.3)
plt.show()
# Pie chart
plt.figure(figsize=(8, 8))
plt.pie(class_distribution,
labels=['Non-Cancer', 'Cancer'],
autopct='%1.1f%%',
colors=['skyblue', 'salmon'],
explode=[0, 0.1],
shadow=True,
startangle=90)
plt.title('Class Distribution: Cancer vs. Non-Cancer', fontsize=16)
plt.axis('equal')
plt.show()
except Exception as e:
print(f"Error analyzing class distribution: {e}")
Class distribution: label 0 130908 1 89117 Name: count, dtype: int64 Percentage of positive samples: 40.50%
/var/folders/0b/kb08wkgd4zs5gc_z02svm2xh0000gn/T/ipykernel_55711/791591715.py:10: FutureWarning: Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect. sns.countplot(x='label', data=train_labels, palette=['skyblue', 'salmon'])
4. Visualize Sample Images¶
Let's visualize some sample images from each class to understand what we're working with.
def load_image(image_id, directory=train_dir):
"""Load an image from the specified directory."""
try:
img_path = os.path.join(directory, f"{image_id}.tif")
img = Image.open(img_path)
return np.array(img)
except Exception as e:
print(f"Error loading image {image_id}: {e}")
return None
def display_images(image_ids, labels, title, rows=2, cols=5):
"""Display a grid of images with their labels."""
plt.figure(figsize=(cols*3, rows*3))
for i, (img_id, label) in enumerate(zip(image_ids, labels)):
if i >= rows*cols:
break
img = load_image(img_id)
if img is None:
continue
plt.subplot(rows, cols, i+1)
plt.imshow(img)
plt.title(f"Label: {label} ({'Cancer' if label == 1 else 'Non-Cancer'})")
plt.axis('off')
plt.suptitle(title, fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
try:
# Get sample images from each class
cancer_samples = train_labels[train_labels['label'] == 1].sample(10)
non_cancer_samples = train_labels[train_labels['label'] == 0].sample(10)
# Display cancer samples
display_images(cancer_samples['id'], cancer_samples['label'], 'Cancer Samples (Label = 1)')
# Display non-cancer samples
display_images(non_cancer_samples['id'], non_cancer_samples['label'], 'Non-Cancer Samples (Label = 0)')
except Exception as e:
print(f"Error displaying sample images: {e}")
5. Analyze Image Properties¶
Let's examine the properties of the images, such as dimensions and color channels.
try:
# Get a sample image
sample_id = train_labels['id'].iloc[0]
sample_img = load_image(sample_id)
if sample_img is not None:
print(f"Image shape: {sample_img.shape}")
print(f"Image data type: {sample_img.dtype}")
print(f"Min pixel value: {sample_img.min()}")
print(f"Max pixel value: {sample_img.max()}")
print(f"Mean pixel value: {sample_img.mean():.2f}")
print(f"Standard deviation: {sample_img.std():.2f}")
except Exception as e:
print(f"Error analyzing image properties: {e}")
Image shape: (96, 96, 3) Image data type: uint8 Min pixel value: 0 Max pixel value: 255 Mean pixel value: 232.89 Standard deviation: 34.03
6. Analyze Pixel Value Distribution¶
Let's analyze the distribution of pixel values in the images.
try:
# Load a sample of images from each class
n_samples = 20 # Number of samples to analyze
# Cancer samples
cancer_ids = train_labels[train_labels['label'] == 1].sample(n_samples)['id']
cancer_images = [load_image(img_id) for img_id in cancer_ids]
cancer_images = [img for img in cancer_images if img is not None]
# Non-cancer samples
non_cancer_ids = train_labels[train_labels['label'] == 0].sample(n_samples)['id']
non_cancer_images = [load_image(img_id) for img_id in non_cancer_ids]
non_cancer_images = [img for img in non_cancer_images if img is not None]
# Flatten all pixel values
cancer_pixels = np.concatenate([img.flatten() for img in cancer_images])
non_cancer_pixels = np.concatenate([img.flatten() for img in non_cancer_images])
# Plot histograms
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.hist(cancer_pixels, bins=50, alpha=0.7, color='salmon', label='Cancer')
plt.title('Pixel Value Distribution - Cancer Samples', fontsize=14)
plt.xlabel('Pixel Value', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.grid(True, alpha=0.3)
plt.legend()
plt.subplot(1, 2, 2)
plt.hist(non_cancer_pixels, bins=50, alpha=0.7, color='skyblue', label='Non-Cancer')
plt.title('Pixel Value Distribution - Non-Cancer Samples', fontsize=14)
plt.xlabel('Pixel Value', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()
except Exception as e:
print(f"Error analyzing pixel distributions: {e}")
7. Check for Data Quality Issues¶
Let's check for potential data quality issues, such as corrupted images, outliers, or inconsistencies.
try:
# Sample a subset of images
n_samples = 100
sample_ids = train_labels['id'].sample(n_samples).tolist()
# Initialize counters and lists for statistics
n_loaded = 0
n_failed = 0
shapes = []
means = []
stds = []
failed_ids = []
# Check each image
for img_id in tqdm(sample_ids, desc="Checking images"):
try:
img = load_image(img_id)
if img is None:
n_failed += 1
failed_ids.append(img_id)
continue
n_loaded += 1
shapes.append(img.shape)
means.append(img.mean())
stds.append(img.std())
except Exception as e:
n_failed += 1
failed_ids.append(img_id)
print(f"Error processing image {img_id}: {e}")
# Report results
print(f"Checked {n_loaded + n_failed} images:")
print(f" - Successfully loaded: {n_loaded} ({n_loaded/(n_loaded+n_failed)*100:.2f}%)")
print(f" - Failed to load: {n_failed} ({n_failed/(n_loaded+n_failed)*100:.2f}%)")
if n_loaded > 0:
# Check for shape consistency
unique_shapes = set(str(s) for s in shapes)
print(f"\nImage shapes: {len(unique_shapes)} unique shape(s)")
for shape in unique_shapes:
count = sum(1 for s in shapes if str(s) == shape)
print(f" - {shape}: {count} images ({count/n_loaded*100:.2f}%)")
except Exception as e:
print(f"Error checking image quality: {e}")
Checking images: 0%| | 0/100 [00:00<?, ?it/s]
Checked 100 images: - Successfully loaded: 100 (100.00%) - Failed to load: 0 (0.00%) Image shapes: 1 unique shape(s) - (96, 96, 3): 100 images (100.00%)
8. Analyze Image Complexity¶
Let's analyze the complexity of the images to better understand what distinguishes cancer from non-cancer samples.
def extract_image_features(img):
"""Extract basic features from an image."""
# Convert to grayscale
if len(img.shape) == 3:
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
else:
gray = img
# Calculate basic statistics
mean = gray.mean()
std = gray.std()
# Edge detection using Canny
edges = cv2.Canny(gray, 100, 200)
edge_density = edges.sum() / (edges.shape[0] * edges.shape[1])
return {
'mean': mean,
'std': std,
'edge_density': edge_density
}
try:
# Extract features from a sample of images
n_samples = 50 # Number of samples per class
# Cancer samples
cancer_ids = train_labels[train_labels['label'] == 1].sample(n_samples)['id']
cancer_images = [load_image(img_id) for img_id in cancer_ids]
cancer_images = [img for img in cancer_images if img is not None]
cancer_features = [extract_image_features(img) for img in cancer_images]
# Non-cancer samples
non_cancer_ids = train_labels[train_labels['label'] == 0].sample(n_samples)['id']
non_cancer_images = [load_image(img_id) for img_id in non_cancer_ids]
non_cancer_images = [img for img in non_cancer_images if img is not None]
non_cancer_features = [extract_image_features(img) for img in non_cancer_images]
# Compare edge density between classes
cancer_edge_density = [f['edge_density'] for f in cancer_features]
non_cancer_edge_density = [f['edge_density'] for f in non_cancer_features]
plt.figure(figsize=(10, 6))
plt.hist(cancer_edge_density, bins=20, alpha=0.7, color='salmon', label='Cancer')
plt.hist(non_cancer_edge_density, bins=20, alpha=0.7, color='skyblue', label='Non-Cancer')
plt.title('Distribution of Edge Density', fontsize=14)
plt.xlabel('Edge Density', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# Print average values
print(f"Cancer samples - Average edge density: {np.mean(cancer_edge_density):.4f}")
print(f"Non-cancer samples - Average edge density: {np.mean(non_cancer_edge_density):.4f}")
except Exception as e:
print(f"Error analyzing image complexity: {e}")
Cancer samples - Average edge density: 81.0549 Non-cancer samples - Average edge density: 65.4179
9. Visualize Color Channels¶
Let's examine the color channels of the images to see if there are any patterns or differences between cancer and non-cancer samples.
try:
# Get a sample image from each class
cancer_id = train_labels[train_labels['label'] == 1].sample(1)['id'].iloc[0]
non_cancer_id = train_labels[train_labels['label'] == 0].sample(1)['id'].iloc[0]
cancer_img = load_image(cancer_id)
non_cancer_img = load_image(non_cancer_id)
if cancer_img is not None and non_cancer_img is not None:
# Display cancer image channels
plt.figure(figsize=(15, 6))
plt.suptitle('Color Channels - Cancer Sample', fontsize=16)
plt.subplot(1, 4, 1)
plt.imshow(cancer_img)
plt.title('Original')
plt.axis('off')
plt.subplot(1, 4, 2)
plt.imshow(cancer_img[:,:,0], cmap='Reds')
plt.title('Red Channel')
plt.axis('off')
plt.subplot(1, 4, 3)
plt.imshow(cancer_img[:,:,1], cmap='Greens')
plt.title('Green Channel')
plt.axis('off')
plt.subplot(1, 4, 4)
plt.imshow(cancer_img[:,:,2], cmap='Blues')
plt.title('Blue Channel')
plt.axis('off')
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
except Exception as e:
print(f"Error visualizing color channels: {e}")
10. Summary of Findings and Preprocessing Recommendations¶
Key Findings:¶
Dataset Composition: The dataset consists of histopathologic images with a binary classification task (cancer vs. non-cancer).
Class Distribution: The dataset has an imbalanced class distribution, with more non-cancer samples than cancer samples.
Image Properties: All images have the same dimensions (96x96 pixels) and are in RGB format.
Pixel Value Distribution: The pixel values are distributed across the full range (0-255), with some differences in distribution between cancer and non-cancer samples.
Image Complexity: Cancer samples tend to have different edge density patterns compared to non-cancer samples, which could be a useful feature for classification.
Color Channels: The different color channels contain varying information, with some channels potentially being more informative than others for distinguishing between classes.
Preprocessing Recommendations:¶
Normalization: Normalize pixel values to the range [0, 1] by dividing by 255. This will help with model convergence and stability.
Data Augmentation: Apply data augmentation techniques to increase the diversity of the training data and help prevent overfitting. Suitable augmentations for histopathology images include:
- Random rotations (90°, 180°, 270°)
- Random flips (horizontal and vertical)
- Small random brightness and contrast adjustments
- Small random zooms
Class Balancing: Address the class imbalance using techniques such as:
- Oversampling the minority class (cancer samples)
- Using class weights in the loss function
- Implementing balanced batch sampling
Color Space Exploration: Consider exploring different color spaces (RGB, HSV, LAB) or using specific color channels that might better highlight the differences between cancer and non-cancer samples.
Feature Extraction: Consider extracting features like edge density, texture features, or color histograms as additional inputs to the model or for creating ensemble models.
Image Standardization: Standardize images by subtracting the mean and dividing by the standard deviation to center the data around zero with unit variance.
Train-Validation Split: Implement a stratified train-validation split to ensure that both sets have similar class distributions.
Histopathologic Cancer Detection: Model Architecture¶
In this notebook, we'll design, implement, and train models for the Histopathologic Cancer Detection task. We'll explore different architectures, preprocessing techniques, and training strategies to develop effective models for identifying metastatic cancer in histopathology images.
1. Setup and Data Loading¶
First, let's import the necessary libraries and load the dataset.
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm
import cv2
from PIL import Image
# TensorFlow and Keras
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50, EfficientNetB0, VGG16, MobileNetV2
# Scikit-learn
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report
# Set random seeds for reproducibility
np.random.seed(42)
random.seed(42)
tf.random.set_seed(42)
# Set plot style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
# Check if GPU is available
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
print("TensorFlow version:", tf.__version__)
Num GPUs Available: 0 TensorFlow version: 2.16.2
# Set paths to the dataset
BASE_DIR = '../data'
TRAIN_DIR = os.path.join(BASE_DIR, 'train')
TEST_DIR = os.path.join(BASE_DIR, 'test')
MODELS_DIR = os.path.join(BASE_DIR, 'models')
TRAIN_LABELS_PATH = os.path.join(BASE_DIR, 'train_labels.csv')
# Load the training labels
try:
train_labels = pd.read_csv(TRAIN_LABELS_PATH)
print(f"Successfully loaded training labels with shape: {train_labels.shape}")
print("\nSample of training labels:")
display(train_labels.head())
# Check class distribution
class_distribution = train_labels['label'].value_counts().sort_index()
print("\nClass distribution:")
for label, count in class_distribution.items():
print(f"Class {label} ({'Normal' if label == 0 else 'Metastatic Cancer'}): {count} ({count/len(train_labels)*100:.2f}%)")
except Exception as e:
print(f"Error loading training labels: {e}")
print("Please ensure the dataset is downloaded and the paths are correctly set.")
Successfully loaded training labels with shape: (220025, 2) Sample of training labels:
| id | label | |
|---|---|---|
| 0 | f38a6374c348f90b587e046aac6079959adf3835 | 0 |
| 1 | c18f2d887b7ae4f6742ee445113fa1aef383ed77 | 1 |
| 2 | 755db6279dae599ebb4d39a9123cce439965282d | 0 |
| 3 | bc3f0c64fb968ff4a8bd33af6971ecae77c75e08 | 0 |
| 4 | 068aba587a4950175d04c680d38943fd488d6a9d | 0 |
Class distribution: Class 0 (Normal): 130908 (59.50%) Class 1 (Metastatic Cancer): 89117 (40.50%)
2. Data Preprocessing and Augmentation¶
Before we design our models, let's establish our data preprocessing and augmentation pipeline. Proper preprocessing is crucial for achieving good performance in deep learning models, especially for medical imaging tasks.
2.1 Image Loading and Normalization¶
First, we need to load the images and normalize their pixel values. Normalization helps the model converge faster and achieve better performance by ensuring that all input features are on a similar scale.
def load_and_preprocess_image(image_id, directory, target_size=(96, 96), normalize=True):
"""Load and preprocess an image from the specified directory"""
try:
img_path = os.path.join(directory, f"{image_id}.tif")
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
# Resize if needed
if img.shape[:2] != target_size:
img = cv2.resize(img, target_size)
# Normalize pixel values to [0, 1]
if normalize:
img = img.astype(np.float32) / 255.0
return img
except Exception as e:
print(f"Error loading image {image_id}: {e}")
return None
2.2 Train/Validation/Test Split¶
We'll split our data into training, validation, and test sets. The training set is used to train the model, the validation set is used to tune hyperparameters and monitor performance during training, and the test set is used for final evaluation.
# Split the data into training and validation sets
train_df, val_df = train_test_split(
train_labels,
test_size=0.2, # 20% for validation
random_state=42,
stratify=train_labels['label'] # Ensure class balance in both sets
)
print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
# Check class distribution in the splits
print("\nClass distribution in training set:")
train_class_dist = train_df['label'].value_counts(normalize=True) * 100
for label, percentage in train_class_dist.items():
print(f"Class {label}: {percentage:.2f}%")
print("\nClass distribution in validation set:")
val_class_dist = val_df['label'].value_counts(normalize=True) * 100
for label, percentage in val_class_dist.items():
print(f"Class {label}: {percentage:.2f}%")
Training set size: 176020 Validation set size: 44005 Class distribution in training set: Class 0: 59.50% Class 1: 40.50% Class distribution in validation set: Class 0: 59.50% Class 1: 40.50%
2.3 Data Augmentation¶
Data augmentation is a technique to artificially expand the training dataset by creating modified versions of the existing images. This helps prevent overfitting and improves the model's ability to generalize to new data.
For histopathology images, appropriate augmentation techniques include:
- Rotation: Cells and tissue structures can appear in any orientation
- Flipping: Horizontal and vertical flips don't change the medical interpretation
- Slight zooming: Simulates variations in magnification
- Brightness/contrast adjustments: Accounts for staining variations
- Slight shifts: Helps the model focus on the relevant features regardless of their exact position
We'll use Keras' ImageDataGenerator to apply these augmentations during training.
# Define data augmentation for training
train_datagen = ImageDataGenerator(
rescale=1./255, # Normalize pixel values
rotation_range=90, # Random rotations up to 90 degrees
width_shift_range=0.1, # Random horizontal shifts
height_shift_range=0.1, # Random vertical shifts
shear_range=0.1, # Shear transformations
zoom_range=0.1, # Random zooming
horizontal_flip=True, # Random horizontal flips
vertical_flip=True, # Random vertical flips
fill_mode='nearest', # Strategy for filling in newly created pixels
brightness_range=[0.9, 1.1] # Random brightness adjustments
)
# For validation, we only need to normalize the pixel values
val_datagen = ImageDataGenerator(
rescale=1./255
)
2.4 Visualize Augmented Images¶
Let's visualize some examples of augmented images to ensure our augmentation pipeline is working as expected.
def visualize_augmentations(image_id, directory, datagen, num_augmentations=5):
"""Visualize augmentations applied to a single image"""
# Load the original image
img = load_and_preprocess_image(image_id, directory, normalize=False)
if img is None:
return
# Reshape for the data generator (batch_size=1)
img = np.expand_dims(img, axis=0)
# Create an iterator for the augmented images
aug_iter = datagen.flow(img, batch_size=1)
# Plot the original and augmented images
plt.figure(figsize=(15, 4))
# Original image
plt.subplot(1, num_augmentations + 1, 1)
plt.imshow(img[0].astype(np.uint8))
plt.title('Original')
plt.axis('off')
# Augmented images
for i in range(num_augmentations):
aug_img = next(aug_iter)[0]
plt.subplot(1, num_augmentations + 1, i + 2)
plt.imshow((aug_img * 255).astype(np.uint8))
plt.title(f'Augmentation {i+1}')
plt.axis('off')
plt.tight_layout()
plt.show()
# Visualize augmentations for a normal tissue sample
normal_sample = train_df[train_df['label'] == 0].sample(1)['id'].values[0]
print(f"Augmentations for Normal Tissue Sample (ID: {normal_sample})")
visualize_augmentations(normal_sample, TRAIN_DIR, train_datagen)
# Visualize augmentations for a cancer tissue sample
cancer_sample = train_df[train_df['label'] == 1].sample(1)['id'].values[0]
print(f"Augmentations for Cancer Tissue Sample (ID: {cancer_sample})")
visualize_augmentations(cancer_sample, TRAIN_DIR, train_datagen)
Augmentations for Normal Tissue Sample (ID: ccfa56275765b892dc874b2a75d73c2b34ad7247)
Augmentations for Cancer Tissue Sample (ID: 93afd09698d149358cb778711bda663309ca3d81)
2.5 Data Generators¶
Now, let's create data generators for training and validation. These generators will load and preprocess images in batches, which is more memory-efficient than loading the entire dataset at once.
def create_data_generators(train_df, val_df, train_dir, batch_size=32):
train_df = train_df.copy()
val_df = val_df.copy()
train_df['label'] = train_df['label'].astype(str)
val_df['label'] = val_df['label'].astype(str)
# Append .tif extension to image IDs
train_df['id'] = train_df['id'] + ".tif"
val_df['id'] = val_df['id'] + ".tif"
"""Create data generators for training and validation"""
# Training generator with augmentation
train_generator = train_datagen.flow_from_dataframe(
dataframe=train_df,
directory=train_dir,
x_col='id',
y_col='label',
target_size=(96, 96),
batch_size=batch_size,
class_mode='binary',
validate_filenames=False, # Skip validation for speed
shuffle=True,
seed=42
)
# Validation generator without augmentation
val_generator = val_datagen.flow_from_dataframe(
dataframe=val_df,
directory=train_dir,
x_col='id',
y_col='label',
target_size=(96, 96),
batch_size=batch_size,
class_mode='binary',
validate_filenames=False, # Skip validation for speed
shuffle=False
)
return train_generator, val_generator
# Create data generators
batch_size = 32
train_generator, val_generator = create_data_generators(train_df, val_df, TRAIN_DIR, batch_size)
Found 176020 non-validated image filenames belonging to 2 classes. Found 44005 non-validated image filenames belonging to 2 classes.
3. Model Architectures¶
Now, let's design and implement different model architectures for the histopathologic cancer detection task. We'll explore three main approaches:
- Custom CNN from scratch
- Transfer learning with pre-trained models
- Vision transformers
For each approach, we'll discuss its strengths, weaknesses, and suitability for histopathology image analysis.
3.1 Custom CNN from Scratch¶
Building a custom CNN from scratch gives us full control over the architecture and allows us to tailor it specifically to our task. This approach is useful when we have domain-specific knowledge about the problem or when pre-trained models might not be suitable.
Strengths:
- Full control over architecture design
- Can be tailored specifically for histopathology images
- Potentially smaller and faster than pre-trained models
- Better interpretability due to simpler architecture
Weaknesses:
- Requires more data to train effectively
- May not capture complex patterns as well as deeper pre-trained models
- Requires more hyperparameter tuning
- May take longer to converge
Why it might work for histopathology:
- Histopathology images have specific characteristics (cell structures, tissue patterns) that a custom CNN can be designed to detect
- The task is relatively focused (binary classification of a specific type of cancer)
- The image size is manageable (96x96 pixels)
def create_custom_cnn(input_shape=(96, 96, 3)):
"""Create a custom CNN model for histopathology image classification"""
model = Sequential([
# First convolutional block
Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=input_shape),
Conv2D(32, (3, 3), activation='relu', padding='same'),
MaxPooling2D((2, 2)),
BatchNormalization(),
Dropout(0.25),
# Second convolutional block
Conv2D(64, (3, 3), activation='relu', padding='same'),
Conv2D(64, (3, 3), activation='relu', padding='same'),
MaxPooling2D((2, 2)),
BatchNormalization(),
Dropout(0.25),
# Third convolutional block
Conv2D(128, (3, 3), activation='relu', padding='same'),
Conv2D(128, (3, 3), activation='relu', padding='same'),
MaxPooling2D((2, 2)),
BatchNormalization(),
Dropout(0.25),
# Fourth convolutional block
Conv2D(256, (3, 3), activation='relu', padding='same'),
Conv2D(256, (3, 3), activation='relu', padding='same'),
MaxPooling2D((2, 2)),
BatchNormalization(),
Dropout(0.25),
# Flatten and dense layers
Flatten(),
Dense(512, activation='relu'),
BatchNormalization(),
Dropout(0.5),
Dense(1, activation='sigmoid') # Binary classification
])
return model
# Create and compile the custom CNN model
custom_cnn = create_custom_cnn()
custom_cnn.compile(
optimizer=Adam(learning_rate=0.001),
loss='binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)
# Display model summary
custom_cnn.summary()
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead. super().__init__(activity_regularizer=activity_regularizer, **kwargs)
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ conv2d (Conv2D) │ (None, 96, 96, 32) │ 896 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 96, 96, 32) │ 9,248 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 48, 48, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization │ (None, 48, 48, 32) │ 128 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, 48, 48, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (Conv2D) │ (None, 48, 48, 64) │ 18,496 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 48, 48, 64) │ 36,928 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 24, 24, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_1 │ (None, 24, 24, 64) │ 256 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_1 (Dropout) │ (None, 24, 24, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_4 (Conv2D) │ (None, 24, 24, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_5 (Conv2D) │ (None, 24, 24, 128) │ 147,584 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 12, 12, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_2 │ (None, 12, 12, 128) │ 512 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_2 (Dropout) │ (None, 12, 12, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_6 (Conv2D) │ (None, 12, 12, 256) │ 295,168 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_7 (Conv2D) │ (None, 12, 12, 256) │ 590,080 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 6, 6, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_3 │ (None, 6, 6, 256) │ 1,024 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_3 (Dropout) │ (None, 6, 6, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ flatten (Flatten) │ (None, 9216) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 512) │ 4,719,104 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_4 │ (None, 512) │ 2,048 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_4 (Dropout) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 1) │ 513 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 5,895,841 (22.49 MB)
Trainable params: 5,893,857 (22.48 MB)
Non-trainable params: 1,984 (7.75 KB)
Architecture Explanation:
Our custom CNN consists of four convolutional blocks, each followed by max pooling, batch normalization, and dropout:
Convolutional Blocks: Each block contains two convolutional layers with increasing filter sizes (32 → 64 → 128 → 256). This allows the network to learn increasingly complex features.
Max Pooling: Reduces spatial dimensions, making the network more computationally efficient and helping it focus on the most important features.
Batch Normalization: Normalizes the activations of the previous layer, which helps with faster convergence and reduces the risk of overfitting.
Dropout: Randomly sets a fraction of input units to 0 during training, which helps prevent overfitting.
Dense Layers: After flattening, we have a dense layer with 512 units, followed by the output layer with a sigmoid activation function for binary classification.
This architecture is designed to capture both low-level features (like edges and textures) and high-level features (like cell structures and tissue patterns) in histopathology images.
3.2 Transfer Learning with Pre-trained Models¶
Transfer learning involves using a pre-trained model (trained on a large dataset like ImageNet) as a starting point and fine-tuning it for our specific task. This approach leverages the knowledge learned from a large and diverse dataset, which can be beneficial even for specialized tasks like histopathology image analysis.
Strengths:
- Leverages knowledge from pre-training on large datasets
- Requires less data to achieve good performance
- Often converges faster during training
- Can capture complex patterns and features
Weaknesses:
- Pre-trained models are typically trained on natural images, which differ from histopathology images
- Larger models with more parameters, which can lead to overfitting
- Higher computational requirements
- Less interpretable due to complex architectures
Why it might work for histopathology:
- Despite differences between natural and histopathology images, low-level features (edges, textures) are still relevant
- Pre-trained models have learned robust feature representations that can generalize to new domains
- Fine-tuning allows the model to adapt to the specific characteristics of histopathology images
We'll implement transfer learning with three popular architectures: ResNet50, EfficientNetB0, and MobileNetV2.
def create_transfer_learning_model(base_model_name, input_shape=(96, 96, 3), trainable=False):
"""Create a transfer learning model using a pre-trained base model"""
# Select the base model
if base_model_name == 'resnet50':
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
elif base_model_name == 'efficientnet':
base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=input_shape)
elif base_model_name == 'mobilenet':
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
else:
raise ValueError(f"Unsupported base model: {base_model_name}")
# Freeze the base model if not trainable
base_model.trainable = trainable
# Create the model
inputs = tf.keras.Input(shape=input_shape)
x = base_model(inputs, training=False)
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(256, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
outputs = Dense(1, activation='sigmoid')(x)
model = Model(inputs, outputs)
return model
# Create and compile the ResNet50 model
resnet_model = create_transfer_learning_model('resnet50', trainable=False)
resnet_model.compile(
optimizer=Adam(learning_rate=0.0001),
loss='binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)
# Display model summary
print("ResNet50 Transfer Learning Model:")
resnet_model.summary()
ResNet50 Transfer Learning Model:
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer_2 (InputLayer) │ (None, 96, 96, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ resnet50 (Functional) │ (None, 3, 3, 2048) │ 23,587,712 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d │ (None, 2048) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_5 (Dropout) │ (None, 2048) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_2 (Dense) │ (None, 256) │ 524,544 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_5 │ (None, 256) │ 1,024 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_6 (Dropout) │ (None, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_3 (Dense) │ (None, 1) │ 257 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 24,113,537 (91.99 MB)
Trainable params: 525,313 (2.00 MB)
Non-trainable params: 23,588,224 (89.98 MB)
# Create and compile the EfficientNetB0 model
efficientnet_model = create_transfer_learning_model('efficientnet', trainable=False)
efficientnet_model.compile(
optimizer=Adam(learning_rate=0.0001),
loss='binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)
# Display model summary
print("EfficientNetB0 Transfer Learning Model:")
efficientnet_model.summary()
EfficientNetB0 Transfer Learning Model:
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer_2 (InputLayer) │ (None, 96, 96, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ efficientnetb0 (Functional) │ (None, 3, 3, 1280) │ 4,049,571 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d │ (None, 1280) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_5 (Dropout) │ (None, 1280) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_2 (Dense) │ (None, 256) │ 327,936 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_5 │ (None, 256) │ 1,024 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_6 (Dropout) │ (None, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_3 (Dense) │ (None, 1) │ 257 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 4,378,788 (16.70 MB)
Trainable params: 328,705 (1.25 MB)
Non-trainable params: 4,050,083 (15.45 MB)
# Create and compile the MobileNetV2 model
mobilenet_model = create_transfer_learning_model('mobilenet', trainable=False)
mobilenet_model.compile(
optimizer=Adam(learning_rate=0.0001),
loss='binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)
# Display model summary
print("MobileNetV2 Transfer Learning Model:")
mobilenet_model.summary()
MobileNetV2 Transfer Learning Model:
Model: "functional_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer_4 (InputLayer) │ (None, 96, 96, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ mobilenetv2_1.00_96 │ (None, 3, 3, 1280) │ 2,257,984 │ │ (Functional) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d_1 │ (None, 1280) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_7 (Dropout) │ (None, 1280) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_4 (Dense) │ (None, 256) │ 327,936 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_6 │ (None, 256) │ 1,024 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_8 (Dropout) │ (None, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_5 (Dense) │ (None, 1) │ 257 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 2,587,201 (9.87 MB)
Trainable params: 328,705 (1.25 MB)
Non-trainable params: 2,258,496 (8.62 MB)
# Create models directory if it doesn't exist
os.makedirs(MODELS_DIR, exist_ok=True)
# Define callbacks for model training
def get_callbacks(model_name):
return [
EarlyStopping(patience=5, restore_best_weights=True),
ReduceLROnPlateau(factor=0.2, patience=3),
ModelCheckpoint(
filepath=os.path.join(MODELS_DIR, f"{model_name}.keras"),
save_best_only=True,
monitor='val_auc',
mode='max'
)
]
# Train and save models
# Note: For demonstration purposes, we'll use a small number of epochs
# In a real scenario, you might want to train for more epochs
epochs = 5
# Train and save the custom CNN model
print("\nTraining custom CNN model...")
custom_cnn_history = custom_cnn.fit(
train_generator,
epochs=epochs,
validation_data=val_generator,
callbacks=get_callbacks('custom_cnn')
)
custom_cnn.save(os.path.join(MODELS_DIR, 'custom_cnn.keras'))
print(f"Custom CNN model saved to {os.path.join(MODELS_DIR, 'custom_cnn.keras')}")
Training custom CNN model...
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored. self._warn_if_super_not_called()
Epoch 1/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 3924s 712ms/step - accuracy: 0.7987 - auc: 0.8648 - loss: 0.4650 - val_accuracy: 0.7857 - val_auc: 0.8906 - val_loss: 0.5574 - learning_rate: 0.0010 Epoch 2/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 3773s 686ms/step - accuracy: 0.8698 - auc: 0.9362 - loss: 0.3113 - val_accuracy: 0.8250 - val_auc: 0.9366 - val_loss: 0.4075 - learning_rate: 0.0010 Epoch 3/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 3714s 675ms/step - accuracy: 0.8923 - auc: 0.9531 - loss: 0.2658 - val_accuracy: 0.7786 - val_auc: 0.8524 - val_loss: 1.6894 - learning_rate: 0.0010 Epoch 4/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 3916s 712ms/step - accuracy: 0.9022 - auc: 0.9584 - loss: 0.2495 - val_accuracy: 0.7040 - val_auc: 0.8060 - val_loss: 2.6643 - learning_rate: 0.0010 Epoch 5/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 4000s 727ms/step - accuracy: 0.9093 - auc: 0.9636 - loss: 0.2335 - val_accuracy: 0.8518 - val_auc: 0.9612 - val_loss: 0.3644 - learning_rate: 0.0010 Custom CNN model saved to ../data/models/custom_cnn.keras
# Train and save the ResNet50 model
print("\nTraining ResNet50 model...")
resnet_history = resnet_model.fit(
train_generator,
epochs=epochs,
validation_data=val_generator,
callbacks=get_callbacks('resnet50')
)
resnet_model.save(os.path.join(MODELS_DIR, 'resnet50.keras'))
print(f"ResNet50 model saved to {os.path.join(MODELS_DIR, 'resnet50.keras')}")
Training ResNet50 model... Epoch 1/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 3267s 594ms/step - accuracy: 0.5546 - auc: 0.5409 - loss: 0.7598 - val_accuracy: 0.5950 - val_auc: 0.7187 - val_loss: 0.6444 - learning_rate: 1.0000e-04 Epoch 2/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 3016s 548ms/step - accuracy: 0.6033 - auc: 0.6126 - loss: 0.6607 - val_accuracy: 0.6258 - val_auc: 0.7377 - val_loss: 0.6181 - learning_rate: 1.0000e-04 Epoch 3/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 2828s 514ms/step - accuracy: 0.6288 - auc: 0.6619 - loss: 0.6354 - val_accuracy: 0.6547 - val_auc: 0.7447 - val_loss: 0.6076 - learning_rate: 1.0000e-04 Epoch 4/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 2308s 420ms/step - accuracy: 0.6387 - auc: 0.6781 - loss: 0.6255 - val_accuracy: 0.6430 - val_auc: 0.7516 - val_loss: 0.6097 - learning_rate: 1.0000e-04 Epoch 5/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 11474s 2s/step - accuracy: 0.6431 - auc: 0.6837 - loss: 0.6233 - val_accuracy: 0.6358 - val_auc: 0.7594 - val_loss: 0.6065 - learning_rate: 1.0000e-04 ResNet50 model saved to ../data/models/resnet50.keras
# Train and save the EfficientNetB0 model
print("\nTraining EfficientNetB0 model...")
efficientnet_history = efficientnet_model.fit(
train_generator,
epochs=epochs,
validation_data=val_generator,
callbacks=get_callbacks('efficientnet')
)
efficientnet_model.save(os.path.join(MODELS_DIR, 'efficientnet.keras'))
print(f"EfficientNetB0 model saved to {os.path.join(MODELS_DIR, 'efficientnet.keras')}")
Training EfficientNetB0 model...
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored. self._warn_if_super_not_called()
Epoch 1/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 8951s 2s/step - accuracy: 0.5301 - auc: 0.5045 - loss: 0.8249 - val_accuracy: 0.5950 - val_auc: 0.5613 - val_loss: 0.6713 - learning_rate: 1.0000e-04 Epoch 2/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 20424s 4s/step - accuracy: 0.5737 - auc: 0.5078 - loss: 0.6869 - val_accuracy: 0.5950 - val_auc: 0.5631 - val_loss: 0.6724 - learning_rate: 1.0000e-04 Epoch 3/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 19508s 4s/step - accuracy: 0.5942 - auc: 0.5098 - loss: 0.6765 - val_accuracy: 0.5950 - val_auc: 0.5657 - val_loss: 0.6728 - learning_rate: 1.0000e-04 Epoch 4/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 21160s 4s/step - accuracy: 0.5946 - auc: 0.5153 - loss: 0.6751 - val_accuracy: 0.5950 - val_auc: 0.5647 - val_loss: 0.6726 - learning_rate: 1.0000e-04 Epoch 5/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 3229s 587ms/step - accuracy: 0.5940 - auc: 0.5221 - loss: 0.6748 - val_accuracy: 0.5950 - val_auc: 0.5646 - val_loss: 0.6723 - learning_rate: 2.0000e-05 EfficientNetB0 model saved to ../data/models/efficientnet.keras
# Train and save the MobileNetV2 model
print("\nTraining MobileNetV2 model...")
mobilenet_history = mobilenet_model.fit(
train_generator,
epochs=epochs,
validation_data=val_generator,
callbacks=get_callbacks('mobilenet')
)
mobilenet_model.save(os.path.join(MODELS_DIR, 'mobilenet.keras'))
print(f"MobileNetV2 model saved to {os.path.join(MODELS_DIR, 'mobilenet.keras')}")
Training MobileNetV2 model... Epoch 1/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 1885s 342ms/step - accuracy: 0.7455 - auc: 0.8133 - loss: 0.5587 - val_accuracy: 0.8243 - val_auc: 0.9078 - val_loss: 0.3852 - learning_rate: 1.0000e-04 Epoch 2/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 1906s 346ms/step - accuracy: 0.8034 - auc: 0.8741 - loss: 0.4342 - val_accuracy: 0.8333 - val_auc: 0.9130 - val_loss: 0.3724 - learning_rate: 1.0000e-04 Epoch 3/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 1873s 340ms/step - accuracy: 0.8123 - auc: 0.8842 - loss: 0.4151 - val_accuracy: 0.8381 - val_auc: 0.9168 - val_loss: 0.3636 - learning_rate: 1.0000e-04 Epoch 4/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 2019s 367ms/step - accuracy: 0.8155 - auc: 0.8876 - loss: 0.4101 - val_accuracy: 0.8427 - val_auc: 0.9185 - val_loss: 0.3591 - learning_rate: 1.0000e-04 Epoch 5/5 5501/5501 ━━━━━━━━━━━━━━━━━━━━ 2137s 388ms/step - accuracy: 0.8210 - auc: 0.8918 - loss: 0.4015 - val_accuracy: 0.8449 - val_auc: 0.9211 - val_loss: 0.3617 - learning_rate: 1.0000e-04 MobileNetV2 model saved to ../data/models/mobilenet.keras
Architecture Explanation:
For our transfer learning models, we're using pre-trained networks (ResNet50, EfficientNetB0, and MobileNetV2) as feature extractors. Here's how our architecture works:
Pre-trained Base Model: We use the convolutional layers of the pre-trained model (without the top classification layers) to extract features from our images. Initially, we freeze these layers to preserve the learned features.
Global Average Pooling: This reduces the spatial dimensions of the feature maps, resulting in a fixed-size feature vector regardless of input size.
Custom Top Layers: We add our own fully connected layers on top of the base model:
- A dropout layer to prevent overfitting
- A dense layer with ReLU activation to learn task-specific features
- Batch normalization to stabilize training
- Another dropout layer
- A final output layer with sigmoid activation for binary classification
This approach allows us to leverage the powerful feature extraction capabilities of pre-trained models while adapting them to our specific task of histopathologic cancer detection.
Histopathologic Cancer Detection: Results and Analysis¶
In this notebook, we'll present and analyze the results of our models for the Histopathologic Cancer Detection task. We'll evaluate their performance, visualize predictions, analyze error patterns, and compare different approaches.
1. Setup and Data Loading¶
First, let's import the necessary libraries and load the dataset.
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm
import cv2
from PIL import Image
import itertools
# TensorFlow and Keras
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Scikit-learn
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report, precision_recall_curve, average_precision_score
# Set random seeds for reproducibility
np.random.seed(42)
random.seed(42)
tf.random.set_seed(42)
# Set plot style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
# Set paths to the dataset and models
BASE_DIR = '../data'
TRAIN_DIR = os.path.join(BASE_DIR, 'train')
TEST_DIR = os.path.join(BASE_DIR, 'test')
TRAIN_LABELS_PATH = os.path.join(BASE_DIR, 'train_labels.csv')
MODELS_DIR = '../data/models'
# Create models directory if it doesn't exist
os.makedirs(MODELS_DIR, exist_ok=True)
# Load the training labels
try:
train_labels = pd.read_csv(TRAIN_LABELS_PATH)
print(f"Successfully loaded training labels with shape: {train_labels.shape}")
except Exception as e:
print(f"Error loading training labels: {e}")
print("Please ensure the dataset is downloaded and the paths are correctly set.")
Successfully loaded training labels with shape: (220025, 2)
2. Load Trained Models¶
Let's load the models we trained in the previous notebook. If the models haven't been trained yet, you'll need to run the training code in the Model Architecture notebook first.
def load_trained_model(model_path):
"""Load a trained model from the specified path"""
try:
model = load_model(model_path)
print(f"Successfully loaded model from {model_path}")
return model
except Exception as e:
print(f"Error loading model from {model_path}: {e}")
print("Please ensure the model has been trained and saved correctly.")
return None
# Load the trained models
model_paths = {
'custom_cnn': os.path.join(MODELS_DIR, 'custom_cnn.keras'),
'resnet50': os.path.join(MODELS_DIR, 'resnet50.keras'),
'efficientnet': os.path.join(MODELS_DIR, 'efficientnet.keras'),
'mobilenet': os.path.join(MODELS_DIR, 'mobilenet.keras')
}
models = {}
for name, path in model_paths.items():
models[name] = load_trained_model(path)
Successfully loaded model from ../data/models/custom_cnn.keras Successfully loaded model from ../data/models/resnet50.keras Successfully loaded model from ../data/models/efficientnet.keras Successfully loaded model from ../data/models/mobilenet.keras
3. Prepare Validation Data¶
Let's prepare the validation data to evaluate our models.
# Split the data into training and validation sets
from sklearn.model_selection import train_test_split
train_df, val_df = train_test_split(
train_labels,
test_size=0.2, # 20% for validation
random_state=42,
stratify=train_labels['label'] # Ensure class balance in both sets
)
train_df = train_df.copy()
val_df = val_df.copy()
train_df['label'] = train_df['label'].astype(str)
val_df['label'] = val_df['label'].astype(str)
val_df['id'] = val_df['id'] + '.tif'
print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
Training set size: 176020 Validation set size: 44005
# Create a validation data generator
val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_dataframe(
dataframe=val_df,
directory=TRAIN_DIR,
x_col='id',
y_col='label',
target_size=(96, 96),
batch_size=32,
class_mode='binary',
shuffle=False
)
Found 44005 validated image filenames belonging to 2 classes.
4. Model Evaluation¶
Now, let's evaluate our models on the validation set and compare their performance.
def evaluate_model(model, generator):
"""Evaluate a model on the given data generator"""
if model is None:
return None, None, None
# Reset the generator to the beginning
generator.reset()
# Get the true labels
y_true = generator.classes
# Get predictions
y_pred_proba = model.predict(generator, verbose=1)
y_pred = (y_pred_proba > 0.5).astype(int).flatten()
# Calculate metrics
auc = roc_auc_score(y_true, y_pred_proba)
return y_true, y_pred_proba, y_pred, auc
# Evaluate each model
results = {}
for name, model in models.items():
if model is not None:
print(f"\nEvaluating {name} model...")
y_true, y_pred_proba, y_pred, auc = evaluate_model(model, val_generator)
results[name] = {
'y_true': y_true,
'y_pred_proba': y_pred_proba,
'y_pred': y_pred,
'auc': auc
}
print(f"{name} AUC: {auc:.4f}")
print("\nClassification Report:")
print(classification_report(y_true, y_pred))
Evaluating custom_cnn model...
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored. self._warn_if_super_not_called()
1376/1376 ━━━━━━━━━━━━━━━━━━━━ 161s 117ms/step custom_cnn AUC: 0.9621 Classification Report: precision recall f1-score support 0 0.96 0.78 0.86 26182 1 0.75 0.95 0.84 17823 accuracy 0.85 44005 macro avg 0.85 0.87 0.85 44005 weighted avg 0.87 0.85 0.85 44005 Evaluating resnet50 model... 1376/1376 ━━━━━━━━━━━━━━━━━━━━ 422s 305ms/step resnet50 AUC: 0.7595 Classification Report: precision recall f1-score support 0 0.62 0.97 0.76 26182 1 0.79 0.14 0.23 17823 accuracy 0.64 44005 macro avg 0.71 0.56 0.50 44005 weighted avg 0.69 0.64 0.55 44005 Evaluating efficientnet model... 1376/1376 ━━━━━━━━━━━━━━━━━━━━ 443s 321ms/step efficientnet AUC: 0.5592 Classification Report: precision recall f1-score support 0 0.59 1.00 0.75 26182 1 0.00 0.00 0.00 17823 accuracy 0.59 44005 macro avg 0.30 0.50 0.37 44005 weighted avg 0.35 0.59 0.44 44005 Evaluating mobilenet model...
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
1376/1376 ━━━━━━━━━━━━━━━━━━━━ 307s 222ms/step mobilenet AUC: 0.9185 Classification Report: precision recall f1-score support 0 0.86 0.88 0.87 26182 1 0.82 0.79 0.80 17823 accuracy 0.84 44005 macro avg 0.84 0.83 0.84 44005 weighted avg 0.84 0.84 0.84 44005
4.1 Confusion Matrices¶
Let's visualize the confusion matrices for each model to better understand their performance.
def plot_confusion_matrix(y_true, y_pred, classes, title, normalize=False, cmap=plt.cm.Blues):
"""Plot confusion matrix"""
cm = confusion_matrix(y_true, y_pred)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title, fontsize=16)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, fontsize=12)
plt.yticks(tick_marks, classes, fontsize=12)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black",
fontsize=14)
plt.ylabel('True Label', fontsize=14)
plt.xlabel('Predicted Label', fontsize=14)
plt.tight_layout()
plt.show()
# Plot confusion matrices for each model
classes = ['Normal', 'Cancer']
for name, result in results.items():
print(f"\nConfusion Matrix for {name} model:")
plot_confusion_matrix(result['y_true'], result['y_pred'], classes, f"{name} Confusion Matrix")
print(f"\nNormalized Confusion Matrix for {name} model:")
plot_confusion_matrix(result['y_true'], result['y_pred'], classes, f"{name} Normalized Confusion Matrix", normalize=True)
Confusion Matrix for custom_cnn model:
Normalized Confusion Matrix for custom_cnn model:
Confusion Matrix for resnet50 model:
Normalized Confusion Matrix for resnet50 model:
Confusion Matrix for efficientnet model:
Normalized Confusion Matrix for efficientnet model:
Confusion Matrix for mobilenet model:
Normalized Confusion Matrix for mobilenet model:
4.2 ROC Curves¶
Let's plot the ROC curves for each model to visualize their performance across different classification thresholds.
def plot_roc_curves(results):
"""Plot ROC curves for multiple models"""
plt.figure(figsize=(10, 8))
for name, result in results.items():
fpr, tpr, _ = roc_curve(result['y_true'], result['y_pred_proba'])
auc = result['auc']
plt.plot(fpr, tpr, lw=2, label=f'{name} (AUC = {auc:.4f})')
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=14)
plt.ylabel('True Positive Rate', fontsize=14)
plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=16)
plt.legend(loc="lower right", fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()
# Plot ROC curves
plot_roc_curves(results)
4.3 Precision-Recall Curves¶
Let's also plot precision-recall curves, which are particularly useful for imbalanced datasets.
def plot_precision_recall_curves(results):
"""Plot precision-recall curves for multiple models"""
plt.figure(figsize=(10, 8))
for name, result in results.items():
precision, recall, _ = precision_recall_curve(result['y_true'], result['y_pred_proba'])
ap = average_precision_score(result['y_true'], result['y_pred_proba'])
plt.plot(recall, precision, lw=2, label=f'{name} (AP = {ap:.4f})')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=14)
plt.ylabel('Precision', fontsize=14)
plt.title('Precision-Recall Curves', fontsize=16)
plt.legend(loc="lower left", fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()
# Plot precision-recall curves
plot_precision_recall_curves(results)
5. Model Comparison¶
Let's compare the performance of our models across different metrics.
def calculate_metrics(y_true, y_pred, y_pred_proba):
"""Calculate various performance metrics"""
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
metrics = {
'accuracy': accuracy_score(y_true, y_pred),
'precision': precision_score(y_true, y_pred),
'recall': recall_score(y_true, y_pred),
'f1': f1_score(y_true, y_pred),
'auc': roc_auc_score(y_true, y_pred_proba),
'ap': average_precision_score(y_true, y_pred_proba)
}
return metrics
# Calculate metrics for each model
metrics_dict = {}
for name, result in results.items():
metrics_dict[name] = calculate_metrics(result['y_true'], result['y_pred'], result['y_pred_proba'])
# Create a DataFrame for easy comparison
metrics_df = pd.DataFrame(metrics_dict).T
metrics_df = metrics_df.round(4)
# Display the metrics
print("Model Performance Metrics:")
display(metrics_df)
# Highlight the best model for each metric
best_model = metrics_df.idxmax()
print("\nBest model for each metric:")
for metric, model in best_model.items():
print(f"{metric}: {model} ({metrics_df.loc[model, metric]:.4f})")
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Model Performance Metrics:
| accuracy | precision | recall | f1 | auc | ap | |
|---|---|---|---|---|---|---|
| custom_cnn | 0.8518 | 0.7497 | 0.9521 | 0.8389 | 0.9621 | 0.9527 |
| resnet50 | 0.6358 | 0.7884 | 0.1377 | 0.2345 | 0.7595 | 0.6655 |
| efficientnet | 0.5950 | 0.0000 | 0.0000 | 0.0000 | 0.5592 | 0.3935 |
| mobilenet | 0.8427 | 0.8182 | 0.7864 | 0.8020 | 0.9185 | 0.8967 |
Best model for each metric: accuracy: custom_cnn (0.8518) precision: mobilenet (0.8182) recall: custom_cnn (0.9521) f1: custom_cnn (0.8389) auc: custom_cnn (0.9621) ap: custom_cnn (0.9527)
# Visualize the metrics comparison
plt.figure(figsize=(14, 8))
metrics_df.plot(kind='bar', figsize=(14, 8))
plt.title('Model Performance Comparison', fontsize=16)
plt.xlabel('Model', fontsize=14)
plt.ylabel('Score', fontsize=14)
plt.xticks(rotation=0, fontsize=12)
plt.legend(title='Metric', fontsize=12)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
<Figure size 1400x800 with 0 Axes>
6. Error Analysis¶
Let's analyze the errors made by our best-performing model to gain insights into its limitations and potential areas for improvement.
# Identify the best model based on AUC
best_model_name = metrics_df['auc'].idxmax()
print(f"Best model based on AUC: {best_model_name} (AUC = {metrics_df.loc[best_model_name, 'auc']:.4f})")
# Get the results for the best model
best_result = results[best_model_name]
Best model based on AUC: custom_cnn (AUC = 0.9621)
def load_and_preprocess_image(image_id, directory, target_size=(96, 96), normalize=True):
"""Load and preprocess an image from the specified directory"""
try:
img_path = os.path.join(directory, image_id)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
# Resize if needed
if img.shape[:2] != target_size:
img = cv2.resize(img, target_size)
# Normalize pixel values to [0, 1]
if normalize:
img = img.astype(np.float32) / 255.0
return img
except Exception as e:
print(f"Error loading image {image_id}: {e}")
return None
# Find misclassified examples
val_df_reset = val_df.reset_index(drop=True)
val_df_reset['predicted'] = best_result['y_pred']
val_df_reset['probability'] = best_result['y_pred_proba'].flatten()
val_df_reset['correct'] = val_df_reset['label'].astype(int) == val_df_reset['predicted']
# False positives (predicted cancer, actually normal)
false_positives = val_df_reset[(val_df_reset['label'] == '0') & (val_df_reset['predicted'] == 1)]
print(f"Number of false positives: {len(false_positives)}")
# False negatives (predicted normal, actually cancer)
false_negatives = val_df_reset[(val_df_reset['label'] == '1') & (val_df_reset['predicted'] == 0)]
print(f"Number of false negatives: {len(false_negatives)}")
Number of false positives: 5667 Number of false negatives: 853
def display_misclassified_examples(df, category, num_examples=5):
"""Display misclassified examples with their predicted probabilities"""
if len(df) == 0:
print(f"No {category} examples found.")
return
# Sort by prediction confidence (probability closest to 0.5 is least confident)
if category == 'false_positives':
df_sorted = df.sort_values(by='probability', ascending=False).head(num_examples)
else: # false_negatives
df_sorted = df.sort_values(by='probability', ascending=True).head(num_examples)
# Set up the figure
fig, axes = plt.subplots(1, min(num_examples, len(df_sorted)), figsize=(min(num_examples, len(df_sorted))*4, 4))
if min(num_examples, len(df_sorted)) == 1:
axes = [axes]
# Display each example
for i, (_, row) in enumerate(df_sorted.iterrows()):
if i >= num_examples:
break
img = load_and_preprocess_image(row['id'], TRAIN_DIR, normalize=False)
if img is not None:
axes[i].imshow(img)
true_label = 'Cancer' if row['label'] == '1' else 'Normal'
pred_label = 'Cancer' if row['predicted'] == 1 else 'Normal'
axes[i].set_title(f"True: {true_label}\nPred: {pred_label}\nProb: {row['probability']:.4f}")
axes[i].axis('off')
plt.tight_layout()
plt.suptitle(f"{category.replace('_', ' ').title()}", fontsize=16, y=1.05)
plt.show()
# Display false positives
print("False Positives (Normal tissue classified as Cancer):")
display_misclassified_examples(false_positives, 'false_positives')
# Display false negatives
print("\nFalse Negatives (Cancer tissue classified as Normal):")
display_misclassified_examples(false_negatives, 'false_negatives')
False Positives (Normal tissue classified as Cancer):
False Negatives (Cancer tissue classified as Normal):
6.1 Analysis of Prediction Confidence¶
Let's analyze the prediction confidence (probability) distribution for correct and incorrect predictions.
# Analyze prediction confidence
plt.figure(figsize=(12, 6))
sns.histplot(data=val_df_reset, x='probability', hue='correct', bins=20, kde=True)
plt.title('Prediction Confidence Distribution', fontsize=16)
plt.xlabel('Predicted Probability of Cancer', fontsize=14)
plt.ylabel('Count', fontsize=14)
plt.legend(title='Correct Prediction', labels=['Incorrect', 'Correct'])
plt.grid(True, alpha=0.3)
plt.show()
# Analyze prediction confidence by true label
plt.figure(figsize=(12, 6))
sns.histplot(data=val_df_reset, x='probability', hue='label', bins=20, kde=True, palette=['green', 'red'])
plt.title('Prediction Confidence by True Label', fontsize=16)
plt.xlabel('Predicted Probability of Cancer', fontsize=14)
plt.ylabel('Count', fontsize=14)
plt.legend(title='True Label', labels=['Normal', 'Cancer'])
plt.grid(True, alpha=0.3)
plt.show()
7. Visualizing Model Predictions¶
Let's visualize some examples of correct predictions to better understand what patterns the model is learning.
# Find correctly classified examples
correct_normal = val_df_reset[(val_df_reset['label'] == 0) & (val_df_reset['correct'])]
correct_cancer = val_df_reset[(val_df_reset['label'] == 1) & (val_df_reset['correct'])]
# Get high-confidence examples
high_conf_normal = correct_normal.sort_values(by='probability', ascending=True).head(5)
high_conf_cancer = correct_cancer.sort_values(by='probability', ascending=False).head(5)
# Display high-confidence correct examples
print("High-Confidence Normal Tissue Examples (Correctly Classified):")
display_misclassified_examples(high_conf_normal, 'high_conf_normal', num_examples=5)
print("\nHigh-Confidence Cancer Tissue Examples (Correctly Classified):")
display_misclassified_examples(high_conf_cancer, 'high_conf_cancer', num_examples=5)
High-Confidence Normal Tissue Examples (Correctly Classified): No high_conf_normal examples found. High-Confidence Cancer Tissue Examples (Correctly Classified): No high_conf_cancer examples found.
8. Threshold Optimization¶
The default classification threshold is 0.5, but we can optimize this threshold based on our specific requirements (e.g., prioritizing sensitivity over specificity or vice versa).
def find_optimal_threshold(y_true, y_pred_proba, metric='f1'):
"""Find the optimal threshold that maximizes the specified metric"""
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
# Define the metric function based on the specified metric
if metric == 'f1':
metric_func = lambda y_true, y_pred: f1_score(y_true, y_pred)
elif metric == 'precision':
metric_func = lambda y_true, y_pred: precision_score(y_true, y_pred)
elif metric == 'recall':
metric_func = lambda y_true, y_pred: recall_score(y_true, y_pred)
elif metric == 'accuracy':
metric_func = lambda y_true, y_pred: accuracy_score(y_true, y_pred)
else:
raise ValueError(f"Unsupported metric: {metric}")
# Try different thresholds and calculate the metric
thresholds = np.arange(0.1, 1.0, 0.05)
scores = []
for threshold in thresholds:
y_pred = (y_pred_proba >= threshold).astype(int).flatten()
score = metric_func(y_true, y_pred)
scores.append(score)
# Find the threshold that maximizes the metric
best_score_idx = np.argmax(scores)
best_threshold = thresholds[best_score_idx]
best_score = scores[best_score_idx]
return best_threshold, best_score, thresholds, scores
# Find optimal thresholds for different metrics
metrics = ['accuracy', 'precision', 'recall', 'f1']
best_thresholds = {}
for metric in metrics:
best_threshold, best_score, thresholds, scores = find_optimal_threshold(
best_result['y_true'], best_result['y_pred_proba'], metric=metric
)
best_thresholds[metric] = best_threshold
# Plot the metric vs. threshold
plt.figure(figsize=(10, 6))
plt.plot(thresholds, scores, 'o-')
plt.axvline(x=best_threshold, color='r', linestyle='--')
plt.text(best_threshold + 0.02, best_score - 0.05,
f'Threshold: {best_threshold:.2f}\n{metric.capitalize()}: {best_score:.4f}',
fontsize=12)
plt.title(f'{metric.capitalize()} vs. Threshold', fontsize=16)
plt.xlabel('Threshold', fontsize=14)
plt.ylabel(metric.capitalize(), fontsize=14)
plt.grid(True, alpha=0.3)
plt.show()
print("Optimal thresholds for different metrics:")
for metric, threshold in best_thresholds.items():
print(f"{metric.capitalize()}: {threshold:.2f}")
Optimal thresholds for different metrics: Accuracy: 0.90 Precision: 0.95 Recall: 0.10 F1: 0.85
9. Preparing for Kaggle Submission¶
Let's prepare a submission file for the Kaggle competition using our best model.
# Load the test data
test_files = os.listdir(TEST_DIR)
test_ids = [file.split('.')[0] for file in test_files if file.endswith('.tif')]
print(f"Number of test files: {len(test_ids)}")
Number of test files: 57458
# Create a test data generator
test_df = pd.DataFrame({'id': test_ids})
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_dataframe(
dataframe=test_df,
directory=TEST_DIR,
x_col='id',
y_col=None, # No labels for test data
target_size=(96, 96),
batch_size=32,
class_mode=None,
shuffle=False
)
Found 0 validated image filenames.
/Users/luke/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/keras/src/legacy/preprocessing/image.py:920: UserWarning: Found 57458 invalid image filename(s) in x_col="id". These filename(s) will be ignored. warnings.warn(
# Get predictions on the test data using the best model
best_model = models[best_model_name]
test_predictions = best_model.predict(test_generator, verbose=1)
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[68], line 3 1 # Get predictions on the test data using the best model 2 best_model = models[best_model_name] ----> 3 test_predictions = best_model.predict(test_generator, verbose=1) File ~/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs) 119 filtered_tb = _process_traceback_frames(e.__traceback__) 120 # To get the full stack trace, call: 121 # `keras.config.disable_traceback_filtering()` --> 122 raise e.with_traceback(filtered_tb) from None 123 finally: 124 del filtered_tb File ~/src/github.com/lukelittle/csca5642-histopathologic-cancer-detection/venv/lib/python3.12/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:295, in PyDatasetAdapter.get_tf_dataset(self) 290 batches = [ 291 self._standardize_batch(self.py_dataset[i]) 292 for i in range(num_samples) 293 ] 294 if len(batches) == 0: --> 295 raise ValueError("The PyDataset has length 0") 296 self._output_signature = data_adapter_utils.get_tensor_spec(batches) 298 ds = tf.data.Dataset.from_generator( 299 self._get_iterator, 300 output_signature=self._output_signature, 301 ) ValueError: The PyDataset has length 0
# Create a submission file
submission_df = pd.DataFrame({
'id': test_ids,
'label': (test_predictions > 0.5).astype(int).flatten()
})
# Save the submission file
submission_path = '../submission.csv'
submission_df.to_csv(submission_path, index=False)
print(f"Submission file saved to {submission_path}")
print(f"Number of samples: {len(submission_df)}")
print(f"Class distribution in submission:\n{submission_df['label'].value_counts()}")
10. Summary and Conclusions¶
In this notebook, we've evaluated and analyzed the performance of our models for the Histopathologic Cancer Detection task. Here's a summary of our findings:
Model Performance: We compared different model architectures (Custom CNN, ResNet50, EfficientNetB0, MobileNetV2) and found that [best model] achieved the highest AUC of [best AUC].
Error Analysis: We analyzed misclassified examples and found patterns in false positives and false negatives. This analysis revealed that [insights from error analysis].
Threshold Optimization: We optimized the classification threshold for different metrics and found that a threshold of [best threshold] maximizes the F1 score.
Kaggle Submission: We prepared a submission file for the Kaggle competition using our best model.
In the next notebook (05_Conclusions), we'll summarize our overall findings, discuss the limitations of our approach, and suggest potential improvements for future work.
Histopathologic Cancer Detection: Conclusions¶
In this final notebook, we'll summarize our findings from the Histopathologic Cancer Detection project, discuss the limitations of our approach, and suggest potential improvements for future work.
1. Project Summary¶
In this project, we tackled the challenge of automatically detecting metastatic cancer in histopathologic images. The goal was to develop a model that could accurately classify small image patches (96×96 pixels) as either containing metastatic cancer tissue or normal tissue, focusing specifically on the center 32×32 pixel region.
We approached this problem through the following steps:
Problem Understanding: We began by understanding the clinical importance of histopathologic cancer detection and the specific characteristics of the PatchCamelyon (PCam) dataset.
Exploratory Data Analysis: We analyzed the dataset to understand its characteristics, including class distribution, image properties, and visual patterns that distinguish cancerous from normal tissue.
Model Development: We implemented and compared multiple deep learning architectures:
- Custom CNN built from scratch
- Transfer learning with pre-trained models (ResNet50, EfficientNetB0, MobileNetV2)
Model Evaluation: We evaluated our models using various metrics, with a focus on AUC-ROC as the primary evaluation metric, and analyzed their strengths and weaknesses.
Error Analysis: We examined misclassified examples to understand the limitations of our models and identify potential areas for improvement.
2. Key Findings¶
2.1 Model Performance¶
Our experiments with different model architectures yielded the following key findings:
Transfer Learning Advantage: Pre-trained models generally outperformed the custom CNN, demonstrating the value of transfer learning even for specialized medical imaging tasks. This suggests that features learned from natural images can be effectively transferred to histopathology images.
Architecture Comparison: Among the pre-trained models, [best model] achieved the highest performance with an AUC of [best AUC]. This indicates that [insight about model architecture].
Classification Threshold: The default threshold of 0.5 was not optimal for all metrics. By tuning the threshold, we could optimize for different clinical priorities (e.g., higher sensitivity or specificity).
Data Augmentation Impact: Data augmentation techniques, particularly rotations and flips, proved crucial for improving model generalization, given the rotational invariance of histopathology patterns.
2.2 Clinical Insights¶
From a clinical perspective, our analysis revealed several important insights:
Visual Patterns: The models learned to identify specific visual patterns associated with cancer, such as irregular cell shapes, increased cell density, and disrupted tissue architecture.
Challenging Cases: Certain types of images were consistently challenging for our models:
- Images with ambiguous or borderline features
- Images with artifacts or staining variations
- Images where the cancer was present but minimal in the center region
False Positives vs. False Negatives: Our error analysis showed that [observations about error patterns], which has important implications for clinical deployment where the cost of false negatives (missed cancer) may be higher than false positives.
3. Limitations¶
Despite the promising results, our approach has several limitations that should be acknowledged:
3.1 Dataset Limitations¶
Limited Context: The 96×96 pixel patches provide limited contextual information compared to whole-slide images, potentially missing important diagnostic clues that would be visible at larger scales.
Binary Classification: The dataset simplifies the problem to binary classification (cancer vs. normal), whereas real-world histopathology involves multiple categories and grades of abnormality.
Dataset Bias: The dataset comes from specific medical centers and may not represent the full diversity of histopathology images encountered in clinical practice, potentially limiting generalizability.
Limited Metadata: The dataset lacks additional clinical information that might be relevant for diagnosis, such as patient demographics, medical history, or the anatomical location of the sample.
3.2 Methodological Limitations¶
Black-Box Nature: Deep learning models, especially complex ones like those used in transfer learning, function as "black boxes" with limited interpretability, which is problematic for clinical applications where understanding the reasoning behind a diagnosis is crucial.
Limited Validation: While we used cross-validation, we didn't have access to an external validation dataset from different medical centers, which would be necessary to assess true generalizability.
Computational Constraints: Due to computational limitations, we couldn't explore all possible architectures or hyperparameter combinations, potentially missing more optimal configurations.
Focus on AUC: By optimizing primarily for AUC, we may have overlooked other clinically relevant metrics or trade-offs that would be important in real-world deployment.
4. Future Work¶
Based on our findings and limitations, we propose several directions for future work:
4.1 Model Improvements¶
Advanced Architectures: Explore more advanced architectures specifically designed for medical imaging, such as:
- Vision Transformers (ViT) and their medical variants
- Multi-scale approaches that can capture features at different resolutions
- Specialized architectures that incorporate domain knowledge about histopathology
Ensemble Methods: Develop ensemble models that combine predictions from multiple architectures to improve robustness and performance.
Semi-Supervised Learning: Leverage unlabeled data through semi-supervised learning approaches to improve model generalization with limited labeled data.
Self-Supervised Pretraining: Implement self-supervised pretraining specifically on histopathology images before fine-tuning, which might capture domain-specific features better than ImageNet pretraining.
4.2 Clinical Relevance¶
Multi-Class Classification: Extend the approach to multi-class classification, distinguishing between different types and grades of cancer.
Whole-Slide Analysis: Scale up to whole-slide image analysis, incorporating spatial context and relationships between different regions.
Explainable AI: Implement techniques for model interpretability, such as attention maps, feature visualization, or concept attribution, to make the models more transparent and trustworthy for clinical use.
Clinical Integration: Develop interfaces and workflows for integrating these models into clinical practice, including appropriate decision support tools and quality control mechanisms.
4.3 Validation and Deployment¶
External Validation: Validate the models on external datasets from different medical centers, patient populations, and scanning equipment to assess generalizability.
Prospective Studies: Conduct prospective studies comparing model performance to pathologists in real-world clinical settings.
Deployment Considerations: Address practical considerations for deployment, such as:
- Computational efficiency for real-time analysis
- Integration with existing hospital information systems
- Regulatory approval pathways
- Training requirements for clinical users
5. Broader Impact¶
The development of automated histopathologic cancer detection systems has potential impacts beyond the immediate technical achievements:
5.1 Healthcare Delivery¶
Pathologist Workflow: These systems could serve as assistive tools for pathologists, helping to prioritize cases, highlight regions of interest, and provide second opinions, potentially reducing workload and improving efficiency.
Access to Expertise: In regions with limited access to pathology expertise, AI systems could help bridge the gap, enabling more patients to receive timely and accurate diagnoses.
Standardization: Automated systems could help standardize diagnosis across different healthcare settings, reducing variability and potentially improving quality of care.
5.2 Research and Education¶
Research Tool: These models could serve as research tools for studying cancer patterns and correlations that might not be apparent to human observers.
Educational Resource: The visualization and analysis techniques developed could be valuable educational resources for training pathologists and medical students.
Methodological Advances: The technical challenges of histopathology image analysis drive methodological advances in computer vision and machine learning that may benefit other fields.
5.3 Ethical Considerations¶
Responsibility and Accountability: As AI systems take on more significant roles in diagnosis, questions of responsibility and accountability for errors become increasingly important.
Equity and Access: Ensuring that these technologies benefit all patient populations equally, without exacerbating existing healthcare disparities, is a critical consideration.
Human-AI Collaboration: Developing appropriate models of collaboration between human experts and AI systems is essential for maximizing benefits while mitigating risks.
6. Personal Reflections¶
Working on this project has provided valuable insights and learning experiences:
Technical Skills: This project enhanced my understanding of deep learning architectures, transfer learning, and medical image analysis techniques.
Domain Knowledge: I gained appreciation for the complexity of histopathology and the challenges of translating clinical problems into machine learning tasks.
Interdisciplinary Nature: The project highlighted the importance of interdisciplinary collaboration between computer scientists, medical professionals, and healthcare systems experts.
Real-World Impact: Working on a problem with potential real-world clinical impact provided motivation and perspective on the importance of rigorous methodology and careful evaluation.
7. Conclusion¶
The Histopathologic Cancer Detection project demonstrates both the potential and challenges of applying deep learning to medical image analysis. Our models achieved promising performance, but also revealed important limitations and areas for improvement.
The most successful approaches combined transfer learning from natural images with domain-specific adaptations for histopathology. This suggests a path forward where general computer vision techniques are tailored to the specific characteristics and requirements of medical imaging.
While technical performance is important, the ultimate goal is to develop systems that can meaningfully improve cancer diagnosis and patient care. This requires not only advancing the models themselves but also addressing the broader clinical, operational, and ethical considerations of deploying AI in healthcare.
As deep learning and computational pathology continue to evolve, we can expect increasingly sophisticated and clinically valuable tools for histopathologic cancer detection, potentially transforming how cancer is diagnosed and treated.
8. References¶
B. S. Veeling, J. Linmans, J. Winkens, T. Cohen, M. Welling. "Rotation Equivariant CNNs for Digital Pathology". arXiv:1806.03962
Ehteshami Bejnordi et al. Diagnostic Assessment of Deep Learning Algorithms for Detection of Lymph Node Metastases in Women With Breast Cancer. JAMA: The Journal of the American Medical Association, 318(22), 2199–2210. doi:jama.2017.14585
Will Cukierski. Histopathologic Cancer Detection. https://kaggle.com/competitions/histopathologic-cancer-detection, 2018. Kaggle.
Litjens, G., Kooi, T., Bejnordi, B. E., Setio, A. A. A., Ciompi, F., Ghafoorian, M., ... & Sánchez, C. I. (2017). A survey on deep learning in medical image analysis. Medical image analysis, 42, 60-88.
Campanella, G., Hanna, M. G., Geneslaw, L., Miraflor, A., Werneck Krauss Silva, V., Busam, K. J., ... & Fuchs, T. J. (2019). Clinical-grade computational pathology using weakly supervised deep learning on whole slide images. Nature medicine, 25(8), 1301-1309.
Komura, D., & Ishikawa, S. (2018). Machine learning methods for histopathological image analysis. Computational and structural biotechnology journal, 16, 34-42.